Skip to content

Conversation

@penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Nov 25, 2025

Following the plan outlined in #1153, this PR removes the need for a VarInfo when generating a LogDensityFunction, thus allowing us to run models with threaded assume.

Demo (run with 4 threads):

julia> using DynamicPPL, Distributions, ForwardDiff, LogDensityProblems

julia> @model function threaded(N)
           x = Vector{Float64}(undef, N)
           y = Vector{Float64}(undef, N)
           Threads.@threads for i in 1:N
               x[i] ~ Normal()
               y[i] ~ Normal(x[i])
           end
       end
threaded (generic function with 2 methods)

julia> N = 8; model = threaded(N) | (; y = fill(5.0, N))
Model{typeof(threaded), (:N,), (), (), Tuple{Int64}, Tuple{}, DynamicPPL.ConditionContext{@NamedTuple{y::Vector{Float64}}, DefaultContext}}(threaded, (N = 8,), NamedTuple(), ConditionContext((y = [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],), DefaultContext()))

julia> ldf = LogDensityFunction(model; adtype=AutoForwardDiff());
[ Info: (false, NamedTuple(), Dict{VarName, DynamicPPL.RangeAndLinked}(x[1] => DynamicPPL.RangeAndLinked(1:1, false), x[4] => DynamicPPL.RangeAndLinked(4:4, false), x[7] => DynamicPPL.RangeAndLinked(7:7, false), x[2] => DynamicPPL.RangeAndLinked(2:2, false), x[8] => DynamicPPL.RangeAndLinked(8:8, false), x[5] => DynamicPPL.RangeAndLinked(5:5, false), x[3] => DynamicPPL.RangeAndLinked(3:3, false), x[6] => DynamicPPL.RangeAndLinked(6:6, false)), [-0.5756156037349591, 0.865479608116988, -1.7189260721097455, 0.04400504308742852, -0.8371080934410403, -0.3055849861269748, -1.7506258509801373, 1.0961651139578856])

julia> xs = fill(0.0, N); LogDensityProblems.logdensity_and_gradient(ldf, xs)
(-114.70301653127476, [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

julia> N * (logpdf(Normal(), 0.0) + logpdf(Normal(0.0), 5.0)) # expected logpdf
-114.70301653127476

I don't claim that this code is the cleanest (I'm starting to get fed up of the NT/Dict code, and would be very happy to hide it away behind an abstraction). It's probably also very not-performant, but I don't really care because it's a one-time setup cost.

There's also a stray @info. The point of the info is to demonstrate above that the ranges are disjoint, even though there are multiple threads collating RangeAndLinked at the same time. (Obviously, they will in general not be in order because there is no guarantee which index of x is generated first.) This is because TSVI makes the accumulators threadsafe*, and as long as we make sure they don't step on each other's toes when we combine them , the end result will be valid.

* ignoring the threadid indexing issue

I think an atomic accumulator approach (#1137) would certainly be cleaner for this accumulator, but that's orthogonal to the point of this PR.

I tried to use it for NUTS sampling, but the problem is that the NUTS code in Turing itself still needs to generate a VarInfo. That could be refactored along very similar lines to what this PR does, but I didn't do it. In principle, once that is done, it should be 100% possible to sample from this model with NUTS, and also decondition y and do predict on the chain. In effect, pretty much everything should work with threaded assume, except for samplers that require a full VarInfo.

@penelopeysm penelopeysm changed the base branch from main to breaking November 25, 2025 21:47
@github-actions
Copy link
Contributor

github-actions bot commented Nov 25, 2025

Benchmark Report

  • this PR's head: 11319c01aaf4d89010bf73345f7989e7a9e3e2a2
  • base branch: 052bc1950df3a42e14b56eae51af236881092f90

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬───────────────────────────────┬───────────────────────────┬───────────────────────────────┐
│                       │       │             │                   │        │       t(eval) / t(ref)        │     t(grad) / t(eval)     │       t(grad) / t(ref)        │
│                       │       │             │                   │        │ ──────────┬─────────┬──────── │ ──────┬─────────┬──────── │ ──────────┬─────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │      base │ this PR │ speedup │  base │ this PR │ speedup │      base │ this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │    429.26 │     err │     err │  9.96 │     err │     err │   4274.32 │     err │     err │
│                   LDA │    12 │ reversediff │             typed │   true │   2852.37 │     err │     err │  1.91 │     err │     err │   5451.92 │     err │     err │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 144651.19 │     err │     err │  5.45 │     err │     err │ 788678.42 │     err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │  13283.35 │     err │     err │  5.94 │     err │     err │  78884.56 │     err │     err │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │  30799.31 │     err │     err │  9.66 │     err │     err │ 297406.24 │     err │     err │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │   3581.30 │     err │     err │  8.74 │     err │     err │  31283.87 │     err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │     16.13 │     err │     err │  1.86 │     err │     err │     30.02 │     err │     err │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │   2431.56 │     err │     err │ 90.38 │     err │     err │ 219772.32 │     err │     err │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │  23125.54 │     err │     err │ 25.29 │     err │     err │ 584877.02 │     err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │   1037.79 │     err │     err │ 76.57 │     err │     err │  79461.16 │     err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │   2538.56 │     err │     err │  4.35 │     err │     err │  11042.72 │     err │     err │
│           Smorgasbord │   201 │    mooncake │             typed │   true │   2506.62 │     err │     err │  5.43 │     err │     err │  13611.78 │     err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │   2605.10 │     err │     err │ 55.50 │     err │     err │ 144580.71 │     err │     err │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │   2624.64 │     err │     err │ 41.36 │     err │     err │ 108549.66 │     err │     err │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │   2260.80 │     err │     err │ 46.20 │     err │     err │ 104437.64 │     err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼─────────┼─────────┼───────┼─────────┼─────────┼───────────┼─────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │   2308.06 │     err │     err │ 45.08 │     err │     err │ 104053.15 │     err │     err │
│              Submodel │     1 │    mooncake │             typed │   true │     25.24 │     err │     err │  5.29 │     err │     err │    133.50 │     err │     err │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴───────────┴─────────┴─────────┴───────┴─────────┴─────────┴───────────┴─────────┴─────────┘

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants